Inspired by Julia Silge

We are building a model from this week’s #tidytuesday data set on volcano eruptions. The goal is a multiclass random forest classifier that predicts the type of volcano from other volcano characteristics such as latitude, longitude, tectonic setting, and major rock type.

Explore the data

library(tidyverse)

volcano_raw <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-05-12/volcano.csv')

volcano_raw %>% 
  count(primary_volcano_type, sort = TRUE)
## # A tibble: 26 x 2
##    primary_volcano_type     n
##    <chr>                <int>
##  1 Stratovolcano          353
##  2 Stratovolcano(es)      107
##  3 Shield                  85
##  4 Volcanic field          71
##  5 Pyroclastic cone(s)     70
##  6 Caldera                 65
##  7 Complex                 46
##  8 Shield(s)               33
##  9 Submarine               27
## 10 Lava dome(s)            26
## # … with 16 more rows

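We have fewer than 1,000 volcanoes but 26 types, which is too many classes for so little data. Instead of modeling every type, we collapse them into three major classes, folding plural labels such as Stratovolcano(es) and Shield(s) into their singular counterparts:

- stratovolcano
- shield volcano
- everything else (other)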

volcano_df <- volcano_raw %>%
  transmute(
    volcano_type = case_when(
      str_detect(primary_volcano_type, "Stratovolcano") ~ "Stratovolcano",
      str_detect(primary_volcano_type, "Shield") ~ "Shield",
      TRUE ~ "Other"
    ),
    volcano_number, latitude, longitude, elevation,
    tectonic_settings, major_rock_1
  ) %>%
  mutate_if(is.character, factor)

volcano_df %>%
  count(volcano_type, sort = TRUE)
## # A tibble: 3 x 2
##   volcano_type      n
##   <fct>         <int>
## 1 Stratovolcano   461
## 2 Other           379
## 3 Shield          118
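The case_when() call above checks its conditions in order and returns the value of the first one that matches, with TRUE ~ "Other" as the catch-all. A quick check on a handful of labels taken from the type counts above shows how the plural variants fold into the same class:

```r
library(dplyr)
library(stringr)

# A few raw labels from the primary_volcano_type counts above
raw_types <- c("Stratovolcano", "Stratovolcano(es)", "Shield", "Shield(s)",
               "Caldera", "Volcanic field", "Lava dome(s)")

# Same mapping as in volcano_df: the first matching condition wins
collapsed <- case_when(
  str_detect(raw_types, "Stratovolcano") ~ "Stratovolcano",
  str_detect(raw_types, "Shield") ~ "Shield",
  TRUE ~ "Other"
)

data.frame(raw_types, collapsed)
```

Because str_detect() does a substring (regex) match, "Stratovolcano(es)" and "Shield(s)" are caught by the same patterns as their singular forms.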

Because the data include latitude and longitude, it is worth plotting the volcanoes on a world map, colored by our new volcano type.

world <- map_data("world")

ggplot() +
  geom_map(data = world, map = world,
           aes(long, lat, map_id = region),
           color = "white", fill = "gray50", alpha = 0.2) + 
  geom_point(data = volcano_df,
             aes(longitude, latitude, color = volcano_type),
             alpha = 0.8)

Build a model

library(tidymodels)
volcano_boot <- bootstraps(volcano_df, times = 500)
volcano_boot
## # Bootstrap sampling 
## # A tibble: 500 x 2
##    splits            id          
##    <list>            <chr>       
##  1 <split [958/348]> Bootstrap001
##  2 <split [958/353]> Bootstrap002
##  3 <split [958/362]> Bootstrap003
##  4 <split [958/352]> Bootstrap004
##  5 <split [958/356]> Bootstrap005
##  6 <split [958/354]> Bootstrap006
##  7 <split [958/356]> Bootstrap007
##  8 <split [958/348]> Bootstrap008
##  9 <split [958/353]> Bootstrap009
## 10 <split [958/352]> Bootstrap010
## # … with 490 more rows

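The classes are imbalanced (only 118 of the 958 volcanoes are shields), so we use step_smote from the themis package, which oversamples the minority classes by generating synthetic examples. Some categorical features also have too many levels for such a small dataset, so we apply step_other to tectonic_settings and major_rock_1 to pool infrequent levels into an "other" category before creating dummy variables. Because SMOTE needs numeric predictors, it comes last, after dummy encoding and normalization.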
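SMOTE creates a synthetic minority-class row by picking a minority example, choosing one of its k nearest minority-class neighbours, and placing a new point at a random position on the line segment between the two. The base-R sketch below illustrates only that interpolation, on made-up two-dimensional data; it is not the themis implementation, which works on all numeric predictors and by default (over_ratio = 1) oversamples every minority class up to the size of the majority class.

```r
set.seed(42)

# Hypothetical minority-class points with two numeric predictors
minority <- matrix(rnorm(20), ncol = 2,
                   dimnames = list(NULL, c("x1", "x2")))

# Create one synthetic point from minority row i using its k nearest neighbours
smote_one <- function(X, i, k = 5) {
  # Euclidean distance from row i to every minority row
  d <- sqrt(rowSums((X - matrix(X[i, ], nrow(X), ncol(X), byrow = TRUE))^2))
  d[i] <- Inf                        # exclude the point itself
  nn <- order(d)[seq_len(k)]         # indices of the k nearest neighbours
  j <- nn[sample.int(k, 1)]          # pick one neighbour at random
  gap <- runif(1)                    # random position along the segment
  X[i, ] + gap * (X[j, ] - X[i, ])
}

# Generate 10 synthetic rows, each seeded from a randomly chosen minority row
synthetic <- t(sapply(sample(nrow(minority), 10, replace = TRUE),
                      function(i) smote_one(minority, i)))
synthetic
```

Every synthetic row lies between two existing minority points, which is why SMOTE depends on the predictors being numeric and on a sensible distance, hence the normalization step before it in the recipe.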

library(themis)

volcano_rec <- recipe(volcano_type ~ ., data = volcano_df) %>% 
  update_role(volcano_number, new_role = "Id") %>% 
  step_other(tectonic_settings) %>% 
  step_other(major_rock_1) %>% 
  step_dummy(tectonic_settings, major_rock_1) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors()) %>% 
  step_smote(volcano_type)

volcano_prep <- prep(volcano_rec)

rf_spec <- rand_forest(trees = 1000) %>% 
  set_mode("classification") %>% 
  set_engine("ranger")

volcano_wf <- workflow() %>% 
  add_recipe(volcano_rec) %>% 
  add_model(rf_spec)

volcano_wf
## ══ Workflow ═══════════════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
## 
## ── Preprocessor ───────────────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## ● step_other()
## ● step_other()
## ● step_dummy()
## ● step_zv()
## ● step_normalize()
## ● step_smote()
## 
## ── Model ──────────────────────────────────────────────────────────────────────────────────
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   trees = 1000
## 
## Computational engine: ranger
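The two step_other steps listed in the workflow above pool factor levels that are rare in the training data; with the default threshold = 0.05, a level that makes up less than 5% of rows is relabelled "other". A minimal base-R sketch of that rule, assuming a hypothetical factor with invented counts standing in for major_rock_1:

```r
# Hypothetical factor with invented counts (100 rows in total)
x <- factor(rep(c("basalt", "andesite", "dacite", "rhyolite", "trachyte"),
                times = c(50, 30, 14, 4, 2)))

# Sketch of step_other's pooling rule: levels whose share of the
# rows is below `threshold` are collapsed into a single "other" level
pool_rare <- function(x, threshold = 0.05, other = "other") {
  freq <- table(x) / length(x)              # share of rows per level
  rare <- names(freq)[freq < threshold]     # levels to pool
  out <- as.character(x)
  out[out %in% rare] <- other
  factor(out)
}

table(pool_rare(x))
```

Here rhyolite (4%) and trachyte (2%) are merged into "other". In the recipe, the frequencies are estimated once, when prep() runs on the training data, and the same pooling is then applied to any new data.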

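For each bootstrap resample, fit_resamples estimates the recipe on the analysis set, fits the random forest to the processed data, and then predicts the held-out assessment rows. Because step_smote is skipped when new data are processed, the assessment rows are not upsampled. Setting save_pred = TRUE keeps the predictions so we can analyze them afterward.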
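Each bootstrap split draws 958 rows with replacement for the analysis set; the rows never drawn form the assessment (out-of-bag) set. On average a fraction of about (1 - 1/n)^n, close to 1/e or 36.8%, is left out, which matches the assessment sizes of roughly 350 in the volcano_boot printout above. A minimal base-R sketch of one such split (an illustration, not rsample's internals):

```r
set.seed(1)
n <- 958                                   # rows in volcano_df

# Analysis set: n row indices drawn with replacement
in_bag <- sample.int(n, n, replace = TRUE)

# Assessment (out-of-bag) set: rows that were never drawn
out_of_bag <- setdiff(seq_len(n), in_bag)

length(in_bag)        # always 958
length(out_of_bag)    # varies by draw, typically around 350

# Expected out-of-bag size: n * (1 - 1/n)^n, close to n / e
expected_oob <- n * (1 - 1 / n)^n
round(expected_oob, 1)
```

The expected value works out to about 352 rows, in line with splits such as [958/348] and [958/356] shown above.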

volcano_res <- fit_resamples(
  volcano_wf,
  resamples = volcano_boot,
  control = control_resamples(save_pred = TRUE,
                              verbose = TRUE)
)

Explore results

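How good is our model? The n column below shows that these metrics are averaged over 25 bootstrap resamples (the bootstraps() default), not the 500 created above.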

volcano_res %>% 
  collect_metrics()
## # A tibble: 2 x 5
##   .metric  .estimator  mean     n std_err
##   <chr>    <chr>      <dbl> <int>   <dbl>
## 1 accuracy multiclass 0.653    25 0.00332
## 2 roc_auc  hand_till  0.791    25 0.00378
volcano_res %>% 
  collect_predictions() %>% 
  conf_mat(volcano_type, .pred_class)
##                Truth
## Prediction      Other Shield Stratovolcano
##   Other          1938    328           746
##   Shield          274    552           211
##   Stratovolcano  1273    212          3247
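As a sanity check, overall accuracy can be recomputed from the pooled confusion matrix, since correct predictions sit on the diagonal. The sketch below copies the counts from the output above and also computes per-class precision (PPV) and recall. yardstick's ppv() uses macro averaging for multiclass problems by default, an unweighted mean of the per-class values, which is the quantity the per-resample histogram that follows summarizes.

```r
# Pooled out-of-bag confusion matrix copied from the output above
# (rows = predicted class, columns = true class)
cm <- matrix(
  c(1938,  328,  746,
     274,  552,  211,
    1273,  212, 3247),
  nrow = 3, byrow = TRUE,
  dimnames = list(
    Prediction = c("Other", "Shield", "Stratovolcano"),
    Truth      = c("Other", "Shield", "Stratovolcano")
  )
)

# Overall accuracy: correct predictions (diagonal) over all predictions
accuracy <- sum(diag(cm)) / sum(cm)

# Per-class precision (PPV): diagonal over row totals (predicted counts)
ppv <- diag(cm) / rowSums(cm)

# Per-class recall (sensitivity): diagonal over column totals (true counts)
recall <- diag(cm) / colSums(cm)

# Macro-averaged PPV: unweighted mean over the three classes
macro_ppv <- mean(ppv)

round(accuracy, 3)   # 0.653
round(ppv, 3)
round(recall, 3)
round(macro_ppv, 3)
```

The pooled accuracy agrees with the 0.653 reported by collect_metrics(). Recall is highest for stratovolcanoes and lowest for shields, the smallest class. Pooled and per-resample averages need not match exactly in general, because each resample has a slightly different number of out-of-bag rows.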
volcano_res %>% 
  collect_predictions() %>% 
  group_by(id) %>% 
  ppv(volcano_type, .pred_class) %>% 
  ggplot(aes(.estimate)) + geom_histogram(bins = 10)

Next, we fit the random forest once to the full data set, preprocessed with the prepped recipe, and use the vip package to plot permutation-based feature importance.
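Permutation importance measures how much a model's performance drops when the values of one predictor are shuffled, breaking its link with the outcome. ranger computes this on out-of-bag data when the engine is set with importance = "permutation"; the base-R sketch below is a simplified, model-agnostic version that assumes made-up data and a hand-written stand-in model, not ranger's implementation.

```r
set.seed(2020)
n <- 1000

# Hypothetical data: y depends on x1 only; x2 is pure noise
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
toy$y <- ifelse(toy$x1 + rnorm(n, sd = 0.5) > 0, "A", "B")

# A fixed stand-in "model" that only looks at x1
predict_toy <- function(data) ifelse(data$x1 > 0, "A", "B")

accuracy <- function(data) mean(predict_toy(data) == data$y)
baseline <- accuracy(toy)

# Permutation importance: shuffle one predictor, measure the accuracy drop
perm_importance <- sapply(c("x1", "x2"), function(var) {
  shuffled <- toy
  shuffled[[var]] <- sample(shuffled[[var]])   # break the link with y
  baseline - accuracy(shuffled)
})

perm_importance
```

Shuffling x1 costs a large share of the accuracy, while shuffling x2 costs nothing, because the model never uses it. A predictor that ranks high in the vip plot is one whose shuffling hurts the forest the most.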

library(vip)

rf_spec %>% 
  set_engine("ranger", importance = "permutation") %>% 
  fit(volcano_type ~ .,
      data = juice(volcano_prep) %>% 
        select(-volcano_number) %>% 
        janitor::clean_names()) %>% 
  vip(geom = "point")

In the plot above, longitude and latitude are the most important predictors of volcano type, followed by the dummy variable for the major rock type basalt/picro-basalt. Since location matters most, we can map the out-of-bag predictions to see where the model classifies volcanoes correctly. Faceting by volcano type with facet_wrap then shows which types are predicted well, and where.

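volcano_pred <- volcano_res %>% 
  collect_predictions() %>% 
  mutate(correct = volcano_type == .pred_class) %>% 
  # .row indexes the rows of volcano_df; joining recovers latitude and longitude
  left_join(volcano_df %>% mutate(.row = row_number()),
            by = c(".row", "volcano_type"))
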
ggplot() +
  geom_map(data = world, map = world,
           aes(long, lat, map_id = region),
           color = "white", fill = "gray50", alpha = 0.2) + 
  stat_summary_hex(data = volcano_pred,
                  aes(longitude, latitude, z = correct),
                  fun = "mean", 
                  alpha = 0.5, 
                  size = 0.5, bins = 60) +
  scale_fill_gradient(high = "cyan3", labels = scales::percent) +
  labs(x = "longitude", y = "latitude", fill = "Percent classified\ncorrectly")

ggplot() +
  geom_map(data = world, map = world,
           aes(long, lat, map_id = region),
           color = "white", fill = "gray50", alpha = 0.2) + 
  stat_summary_hex(data = volcano_pred,
                  aes(longitude, latitude, z = correct),
                  fun = "mean", 
                  alpha = 0.5, 
                  size = 0.5, bins = 60) +
  facet_wrap(~ volcano_type) + 
  scale_fill_gradient(high = "cyan3", labels = scales::percent) +
  labs(x = "longitude", y = "latitude", fill = "Percent classified\ncorrectly")